{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0TD5ZrvEMbhZ"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\").\n",
        "\n",
        "# Pix2Pix: An example with tf.keras and eager\n",
        "\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
        "\u003ca target=\"_blank\"  href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\n",
        "    \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e  \n",
        "\u003c/td\u003e\u003ctd\u003e\n",
        "\u003ca target=\"_blank\"  href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ITZuApL56Mny"
      },
      "source": [
        "This notebook demonstrates image to image translation using conditional GAN's, as described in [Image-to-Image Translation with Conditional Adversarial Networks](https://arxiv.org/abs/1611.07004). Using this technique we can colorize black and white photos, convert google maps to google earth, etc. Here, we convert building facades to real buildings. We use [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager) to achieve this.\n",
        "\n",
        "In example, we will use the [CMP Facade Database](http://cmp.felk.cvut.cz/~tylecr1/facade/), helpfully provided by the [Center for Machine Perception](http://cmp.felk.cvut.cz/) at the [Czech Technical University in Prague](https://www.cvut.cz/). To keep our example short, we will use a preprocessed [copy](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/) of this dataset, created by the authors of the [paper](https://arxiv.org/abs/1611.07004) above.\n",
        "\n",
        "Each epoch takes around 58 seconds on a single P100 GPU.\n",
        "\n",
        "Below is the output generated after training the model for 200 epochs.\n",
        "\n",
        "\n",
        "![sample output_1](https://www.tensorflow.org/images/gan/pix2pix_1.png)\n",
        "![sample output_2](https://www.tensorflow.org/images/gan/pix2pix_2.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "e1_Y75QXJS6h"
      },
      "source": [
        "## Import TensorFlow and enable eager execution"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "YfIk2es3hJEd"
      },
      "outputs": [],
      "source": [
        "# Import TensorFlow \u003e= 1.10 and enable eager execution\n",
        "import tensorflow as tf\n",
        "tf.enable_eager_execution()\n",
        "\n",
        "import os\n",
        "import time\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "import PIL\n",
        "from IPython.display import clear_output"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "iYn4MdZnKCey"
      },
      "source": [
        "## Load the dataset\n",
        "\n",
        "You can download this dataset and similar datasets from [here](https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets). As mentioned in the [paper](https://arxiv.org/abs/1611.07004) we apply random jittering and mirroring to the training dataset.\n",
        "* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`\n",
        "* In random mirroring, the image is randomly flipped horizontally i.e left to right."
      ]
    },
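    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough illustration of the two augmentations listed above, here is a minimal NumPy sketch. It is not the notebook's TensorFlow pipeline: `random_jitter`, `resize_nearest`, and the `rng` argument are hypothetical names introduced only for this example, and the index-based nearest-neighbour resize is a simplified stand-in for `tf.image.resize_images`. What it tries to show is that the crop offset and the flip decision are drawn once and shared by both images, so the input/target pair stays pixel-aligned.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def random_jitter(input_image, real_image, rng, resize_to=286, crop_to=256):\n",
        "    # Hypothetical helper: nearest-neighbour resize of an (H, W, C) image.\n",
        "    def resize_nearest(img):\n",
        "        rows = np.arange(resize_to) * img.shape[0] // resize_to\n",
        "        cols = np.arange(resize_to) * img.shape[1] // resize_to\n",
        "        return img[rows][:, cols]\n",
        "\n",
        "    input_image = resize_nearest(input_image)\n",
        "    real_image = resize_nearest(real_image)\n",
        "\n",
        "    # One shared crop offset keeps the two images aligned.\n",
        "    top = rng.integers(0, resize_to - crop_to + 1)\n",
        "    left = rng.integers(0, resize_to - crop_to + 1)\n",
        "    input_image = input_image[top:top + crop_to, left:left + crop_to]\n",
        "    real_image = real_image[top:top + crop_to, left:left + crop_to]\n",
        "\n",
        "    # One shared coin flip mirrors both images or neither.\n",
        "    if rng.integers(0, 2) == 1:\n",
        "        input_image = input_image[:, ::-1]\n",
        "        real_image = real_image[:, ::-1]\n",
        "    return input_image, real_image\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "image = rng.integers(0, 256, size=(256, 256, 3)).astype(np.float32)\n",
        "# Feeding the same image twice makes the alignment easy to check.\n",
        "jittered_input, jittered_real = random_jitter(image, image.copy(), rng)\n",
        "```\n",
        "\n",
        "Because the same image is passed as both halves of the pair, the two outputs should be identical after jittering, which is a quick way to confirm that the crop and the flip were applied consistently."
      ]
    },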
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Kn-k8kTXuAlv"
      },
      "outputs": [],
      "source": [
        "path_to_zip = tf.keras.utils.get_file('facades.tar.gz',\n",
        "                                      cache_subdir=os.path.abspath('.'),\n",
        "                                      origin='https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz', \n",
        "                                      extract=True)\n",
        "\n",
        "PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "2CbTEt448b4R"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = 400\n",
        "BATCH_SIZE = 1\n",
        "IMG_WIDTH = 256\n",
        "IMG_HEIGHT = 256"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tyaP4hLJ8b4W"
      },
      "outputs": [],
      "source": [
        "def load_image(image_file, is_train):\n",
        "  image = tf.read_file(image_file)\n",
        "  image = tf.image.decode_jpeg(image)\n",
        "\n",
        "  w = tf.shape(image)[1]\n",
        "\n",
        "  w = w // 2\n",
        "  real_image = image[:, :w, :]\n",
        "  input_image = image[:, w:, :]\n",
        "\n",
        "  input_image = tf.cast(input_image, tf.float32)\n",
        "  real_image = tf.cast(real_image, tf.float32)\n",
        "\n",
        "  if is_train:\n",
        "    # random jittering\n",
        "    \n",
        "    # resizing to 286 x 286 x 3\n",
        "    input_image = tf.image.resize_images(input_image, [286, 286], \n",
        "                                        align_corners=True, \n",
        "                                        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "    real_image = tf.image.resize_images(real_image, [286, 286], \n",
        "                                        align_corners=True, \n",
        "                                        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "    \n",
        "    # randomly cropping to 256 x 256 x 3\n",
        "    stacked_image = tf.stack([input_image, real_image], axis=0)\n",
        "    cropped_image = tf.random_crop(stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])\n",
        "    input_image, real_image = cropped_image[0], cropped_image[1]\n",
        "\n",
        "    if np.random.random() \u003e 0.5:\n",
        "      # random mirroring\n",
        "      input_image = tf.image.flip_left_right(input_image)\n",
        "      real_image = tf.image.flip_left_right(real_image)\n",
        "  else:\n",
        "    input_image = tf.image.resize_images(input_image, size=[IMG_HEIGHT, IMG_WIDTH], \n",
        "                                         align_corners=True, method=2)\n",
        "    real_image = tf.image.resize_images(real_image, size=[IMG_HEIGHT, IMG_WIDTH], \n",
        "                                        align_corners=True, method=2)\n",
        "  \n",
        "  # normalizing the images to [-1, 1]\n",
        "  input_image = (input_image / 127.5) - 1\n",
        "  real_image = (real_image / 127.5) - 1\n",
        "\n",
        "  return input_image, real_image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "PIGN6ouoQxt3"
      },
      "source": [
        "## Use tf.data to create batches, map(do preprocessing) and shuffle the dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "SQHmYSmk8b4b"
      },
      "outputs": [],
      "source": [
        "train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')\n",
        "train_dataset = train_dataset.shuffle(BUFFER_SIZE)\n",
        "train_dataset = train_dataset.map(lambda x: load_image(x, True))\n",
        "train_dataset = train_dataset.batch(1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "MS9J0yA58b4g"
      },
      "outputs": [],
      "source": [
        "test_dataset = tf.data.Dataset.list_files(PATH+'test/*.jpg')\n",
        "test_dataset = test_dataset.map(lambda x: load_image(x, False))\n",
        "test_dataset = test_dataset.batch(1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "THY-sZMiQ4UV"
      },
      "source": [
        "## Write the generator and discriminator models\n",
        "\n",
        "* **Generator** \n",
        "  * The architecture of generator is a modified U-Net.\n",
        "  * Each block in the encoder is (Conv -\u003e Batchnorm -\u003e Leaky ReLU)\n",
        "  * Each block in the decoder is (Transposed Conv -\u003e Batchnorm -\u003e Dropout(applied to the first 3 blocks) -\u003e ReLU)\n",
        "  * There are skip connections between the encoder and decoder (as in U-Net).\n",
        "  \n",
        "* **Discriminator**\n",
        "  * The Discriminator is a PatchGAN.\n",
        "  * Each block in the discriminator is (Conv -\u003e BatchNorm -\u003e Leaky ReLU)\n",
        "  * The shape of the output after the last layer is (batch_size, 30, 30, 1)\n",
        "  * Each 30x30 patch of the output classifies a 70x70 portion of the input image (such an architecture is called a PatchGAN).\n",
        "  * Discriminator receives 2 inputs.\n",
        "    * Input image and the target image, which it should classify as real.\n",
        "    * Input image and the generated image (output of generator), which it should classify as fake. \n",
        "    * We concatenate these 2 inputs together in the code (`tf.concat([inp, tar], axis=-1)`)\n",
        "\n",
        "* Shape of the input travelling through the generator and the discriminator is in the comments in the code.\n",
        "\n",
        "To learn more about the architecture and the hyperparameters you can refer the [paper](https://arxiv.org/abs/1611.07004).\n",
        "    "
      ]
    },
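    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The 30x30 output and the 70x70 patch size can be checked with a little arithmetic. The sketch below is plain Python, not part of the model code; the helper names `same_conv_output` and `valid_conv_output` are made up for this example. It assumes the discriminator layout used in this notebook: three 4x4 stride-2 convolutions with `padding='same'`, followed by two 4x4 stride-1 convolutions that are each preceded by a zero padding of 1 pixel per side.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def same_conv_output(size, stride):\n",
        "    # padding='same': the output size depends only on the stride.\n",
        "    return math.ceil(size / stride)\n",
        "\n",
        "def valid_conv_output(size, kernel, stride=1):\n",
        "    # No implicit padding: the kernel must fit entirely inside the input.\n",
        "    return (size - kernel) // stride + 1\n",
        "\n",
        "size = 256\n",
        "for _ in range(3):\n",
        "    size = same_conv_output(size, 2)       # 128, then 64, then 32\n",
        "size = valid_conv_output(size + 2, 4)      # pad to 34, 4x4 conv gives 31\n",
        "size = valid_conv_output(size + 2, 4)      # pad to 33, 4x4 conv gives 30\n",
        "patch_grid = size\n",
        "\n",
        "# Receptive field of a single output element.\n",
        "# Each entry is (kernel, stride) for one convolution, in order.\n",
        "layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]\n",
        "receptive_field, jump = 1, 1\n",
        "for kernel, stride in layers:\n",
        "    receptive_field += (kernel - 1) * jump  # widen by the kernel extent\n",
        "    jump *= stride                          # input pixels per output step\n",
        "\n",
        "print(patch_grid, receptive_field)\n",
        "```\n",
        "\n",
        "Under these assumptions the loop ends with a 30x30 grid in which every element sees a 70x70 region of the input."
      ]
    },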
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tqqvWxlw8b4l"
      },
      "outputs": [],
      "source": [
        "OUTPUT_CHANNELS = 3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "lFPI4Nu-8b4q"
      },
      "outputs": [],
      "source": [
        "class Downsample(tf.keras.Model):\n",
        "    \n",
        "  def __init__(self, filters, size, apply_batchnorm=True):\n",
        "    super(Downsample, self).__init__()\n",
        "    self.apply_batchnorm = apply_batchnorm\n",
        "    initializer = tf.random_normal_initializer(0., 0.02)\n",
        "\n",
        "    self.conv1 = tf.keras.layers.Conv2D(filters, \n",
        "                                        (size, size), \n",
        "                                        strides=2, \n",
        "                                        padding='same',\n",
        "                                        kernel_initializer=initializer,\n",
        "                                        use_bias=False)\n",
        "    if self.apply_batchnorm:\n",
        "        self.batchnorm = tf.keras.layers.BatchNormalization()\n",
        "  \n",
        "  def call(self, x, training):\n",
        "    x = self.conv1(x)\n",
        "    if self.apply_batchnorm:\n",
        "        x = self.batchnorm(x, training=training)\n",
        "    x = tf.nn.leaky_relu(x)\n",
        "    return x \n",
        "\n",
        "\n",
        "class Upsample(tf.keras.Model):\n",
        "    \n",
        "  def __init__(self, filters, size, apply_dropout=False):\n",
        "    super(Upsample, self).__init__()\n",
        "    self.apply_dropout = apply_dropout\n",
        "    initializer = tf.random_normal_initializer(0., 0.02)\n",
        "\n",
        "    self.up_conv = tf.keras.layers.Conv2DTranspose(filters, \n",
        "                                                   (size, size), \n",
        "                                                   strides=2, \n",
        "                                                   padding='same',\n",
        "                                                   kernel_initializer=initializer,\n",
        "                                                   use_bias=False)\n",
        "    self.batchnorm = tf.keras.layers.BatchNormalization()\n",
        "    if self.apply_dropout:\n",
        "        self.dropout = tf.keras.layers.Dropout(0.5)\n",
        "\n",
        "  def call(self, x1, x2, training):\n",
        "    x = self.up_conv(x1)\n",
        "    x = self.batchnorm(x, training=training)\n",
        "    if self.apply_dropout:\n",
        "        x = self.dropout(x, training=training)\n",
        "    x = tf.nn.relu(x)\n",
        "    x = tf.concat([x, x2], axis=-1)\n",
        "    return x\n",
        "\n",
        "\n",
        "class Generator(tf.keras.Model):\n",
        "    \n",
        "  def __init__(self):\n",
        "    super(Generator, self).__init__()\n",
        "    initializer = tf.random_normal_initializer(0., 0.02)\n",
        "    \n",
        "    self.down1 = Downsample(64, 4, apply_batchnorm=False)\n",
        "    self.down2 = Downsample(128, 4)\n",
        "    self.down3 = Downsample(256, 4)\n",
        "    self.down4 = Downsample(512, 4)\n",
        "    self.down5 = Downsample(512, 4)\n",
        "    self.down6 = Downsample(512, 4)\n",
        "    self.down7 = Downsample(512, 4)\n",
        "    self.down8 = Downsample(512, 4)\n",
        "\n",
        "    self.up1 = Upsample(512, 4, apply_dropout=True)\n",
        "    self.up2 = Upsample(512, 4, apply_dropout=True)\n",
        "    self.up3 = Upsample(512, 4, apply_dropout=True)\n",
        "    self.up4 = Upsample(512, 4)\n",
        "    self.up5 = Upsample(256, 4)\n",
        "    self.up6 = Upsample(128, 4)\n",
        "    self.up7 = Upsample(64, 4)\n",
        "\n",
        "    self.last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, \n",
        "                                                (4, 4), \n",
        "                                                strides=2, \n",
        "                                                padding='same',\n",
        "                                                kernel_initializer=initializer)\n",
        "  \n",
        "  @tf.contrib.eager.defun\n",
        "  def call(self, x, training):\n",
        "    # x shape == (bs, 256, 256, 3)    \n",
        "    x1 = self.down1(x, training=training) # (bs, 128, 128, 64)\n",
        "    x2 = self.down2(x1, training=training) # (bs, 64, 64, 128)\n",
        "    x3 = self.down3(x2, training=training) # (bs, 32, 32, 256)\n",
        "    x4 = self.down4(x3, training=training) # (bs, 16, 16, 512)\n",
        "    x5 = self.down5(x4, training=training) # (bs, 8, 8, 512)\n",
        "    x6 = self.down6(x5, training=training) # (bs, 4, 4, 512)\n",
        "    x7 = self.down7(x6, training=training) # (bs, 2, 2, 512)\n",
        "    x8 = self.down8(x7, training=training) # (bs, 1, 1, 512)\n",
        "\n",
        "    x9 = self.up1(x8, x7, training=training) # (bs, 2, 2, 1024)\n",
        "    x10 = self.up2(x9, x6, training=training) # (bs, 4, 4, 1024)\n",
        "    x11 = self.up3(x10, x5, training=training) # (bs, 8, 8, 1024)\n",
        "    x12 = self.up4(x11, x4, training=training) # (bs, 16, 16, 1024)\n",
        "    x13 = self.up5(x12, x3, training=training) # (bs, 32, 32, 512)\n",
        "    x14 = self.up6(x13, x2, training=training) # (bs, 64, 64, 256)\n",
        "    x15 = self.up7(x14, x1, training=training) # (bs, 128, 128, 128)\n",
        "\n",
        "    x16 = self.last(x15) # (bs, 256, 256, 3)\n",
        "    x16 = tf.nn.tanh(x16)\n",
        "\n",
        "    return x16"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ll6aNeQx8b4v"
      },
      "outputs": [],
      "source": [
        "class DiscDownsample(tf.keras.Model):\n",
        "    \n",
        "  def __init__(self, filters, size, apply_batchnorm=True):\n",
        "    super(DiscDownsample, self).__init__()\n",
        "    self.apply_batchnorm = apply_batchnorm\n",
        "    initializer = tf.random_normal_initializer(0., 0.02)\n",
        "\n",
        "    self.conv1 = tf.keras.layers.Conv2D(filters, \n",
        "                                        (size, size), \n",
        "                                        strides=2, \n",
        "                                        padding='same',\n",
        "                                        kernel_initializer=initializer,\n",
        "                                        use_bias=False)\n",
        "    if self.apply_batchnorm:\n",
        "        self.batchnorm = tf.keras.layers.BatchNormalization()\n",
        "  \n",
        "  def call(self, x, training):\n",
        "    x = self.conv1(x)\n",
        "    if self.apply_batchnorm:\n",
        "        x = self.batchnorm(x, training=training)\n",
        "    x = tf.nn.leaky_relu(x)\n",
        "    return x \n",
        "\n",
        "class Discriminator(tf.keras.Model):\n",
        "    \n",
        "  def __init__(self):\n",
        "    super(Discriminator, self).__init__()\n",
        "    initializer = tf.random_normal_initializer(0., 0.02)\n",
        "    \n",
        "    self.down1 = DiscDownsample(64, 4, False)\n",
        "    self.down2 = DiscDownsample(128, 4)\n",
        "    self.down3 = DiscDownsample(256, 4)\n",
        "    \n",
        "    # we are zero padding here with 1 because we need our shape to \n",
        "    # go from (batch_size, 32, 32, 256) to (batch_size, 31, 31, 512)\n",
        "    self.zero_pad1 = tf.keras.layers.ZeroPadding2D()\n",
        "    self.conv = tf.keras.layers.Conv2D(512, \n",
        "                                       (4, 4), \n",
        "                                       strides=1, \n",
        "                                       kernel_initializer=initializer, \n",
        "                                       use_bias=False)\n",
        "    self.batchnorm1 = tf.keras.layers.BatchNormalization()\n",
        "    \n",
        "    # shape change from (batch_size, 31, 31, 512) to (batch_size, 30, 30, 1)\n",
        "    self.zero_pad2 = tf.keras.layers.ZeroPadding2D()\n",
        "    self.last = tf.keras.layers.Conv2D(1, \n",
        "                                       (4, 4), \n",
        "                                       strides=1,\n",
        "                                       kernel_initializer=initializer)\n",
        "  \n",
        "  @tf.contrib.eager.defun\n",
        "  def call(self, inp, tar, training):\n",
        "    # concatenating the input and the target\n",
        "    x = tf.concat([inp, tar], axis=-1) # (bs, 256, 256, channels*2)\n",
        "    x = self.down1(x, training=training) # (bs, 128, 128, 64)\n",
        "    x = self.down2(x, training=training) # (bs, 64, 64, 128)\n",
        "    x = self.down3(x, training=training) # (bs, 32, 32, 256)\n",
        "\n",
        "    x = self.zero_pad1(x) # (bs, 34, 34, 256)\n",
        "    x = self.conv(x)      # (bs, 31, 31, 512)\n",
        "    x = self.batchnorm1(x, training=training)\n",
        "    x = tf.nn.leaky_relu(x)\n",
        "    \n",
        "    x = self.zero_pad2(x) # (bs, 33, 33, 512)\n",
        "    # don't add a sigmoid activation here since\n",
        "    # the loss function expects raw logits.\n",
        "    x = self.last(x)      # (bs, 30, 30, 1)\n",
        "\n",
        "    return x"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gDkA05NE6QMs"
      },
      "outputs": [],
      "source": [
        "# The call function of Generator and Discriminator have been decorated\n",
        "# with tf.contrib.eager.defun()\n",
        "# We get a performance speedup if defun is used (~25 seconds per epoch)\n",
        "generator = Generator()\n",
        "discriminator = Discriminator()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0FMYgY_mPfTi"
      },
      "source": [
        "## Define the loss functions and the optimizer\n",
        "\n",
        "* **Discriminator loss**\n",
        "  * The discriminator loss function takes 2 inputs; **real images, generated images**\n",
        "  * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones(since these are the real images)**\n",
        "  * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros(since these are the fake images)**\n",
        "  * Then the total_loss is the sum of real_loss and the generated_loss\n",
        "  \n",
        "* **Generator loss**\n",
        "  * It is a sigmoid cross entropy loss of the generated images and an **array of ones**.\n",
        "  * The [paper](https://arxiv.org/abs/1611.07004) also includes L1 loss which is MAE (mean absolute error) between the generated image and the target image.\n",
        "  * This allows the generated image to become structurally similar to the target image.\n",
        "  * The formula to calculate the total generator loss = gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the [paper](https://arxiv.org/abs/1611.07004)."
      ]
    },
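    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the loss definitions above concrete, here is a small NumPy sketch that mirrors them outside of TensorFlow. It is an illustration under assumptions rather than the notebook's implementation: `sigmoid_cross_entropy` is a hand-written helper that averages over all elements (assumed to correspond to the default reduction of `tf.losses.sigmoid_cross_entropy` with unit weights), and the inputs at the bottom are made-up toy arrays.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "LAMBDA = 100  # same weighting as in this notebook\n",
        "\n",
        "def sigmoid_cross_entropy(labels, logits):\n",
        "    # Numerically stable form of the per-element loss:\n",
        "    # max(x, 0) - x * z + log(1 + exp(-|x|)), averaged over all elements.\n",
        "    per_element = (np.maximum(logits, 0) - logits * labels\n",
        "                   + np.log1p(np.exp(-np.abs(logits))))\n",
        "    return np.mean(per_element)\n",
        "\n",
        "def discriminator_loss(disc_real_output, disc_generated_output):\n",
        "    # Real pairs are labelled 1, generated pairs are labelled 0.\n",
        "    real_loss = sigmoid_cross_entropy(\n",
        "        np.ones_like(disc_real_output), disc_real_output)\n",
        "    generated_loss = sigmoid_cross_entropy(\n",
        "        np.zeros_like(disc_generated_output), disc_generated_output)\n",
        "    return real_loss + generated_loss\n",
        "\n",
        "def generator_loss(disc_generated_output, gen_output, target):\n",
        "    # The generator wants its outputs to be labelled 1 (real).\n",
        "    gan_loss = sigmoid_cross_entropy(\n",
        "        np.ones_like(disc_generated_output), disc_generated_output)\n",
        "    l1_loss = np.mean(np.abs(target - gen_output))  # mean absolute error\n",
        "    return gan_loss + LAMBDA * l1_loss\n",
        "\n",
        "# Toy inputs: an undecided discriminator (all logits 0) and a generated\n",
        "# image that is off by 0.1 at every pixel.\n",
        "logits = np.zeros((1, 30, 30, 1))\n",
        "target = np.zeros((1, 256, 256, 3))\n",
        "d_loss = discriminator_loss(logits, logits)\n",
        "g_loss = generator_loss(logits, target + 0.1, target)\n",
        "```\n",
        "\n",
        "With all logits at zero, each cross entropy term equals log(2), so the discriminator loss is 2 * log(2) and the generator loss is log(2) + LAMBDA * 0.1. The L1 term dominates even for a small pixel error, which is the effect of choosing LAMBDA = 100."
      ]
    },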
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cyhxTuvJyIHV"
      },
      "outputs": [],
      "source": [
        "LAMBDA = 100"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "wkMNfBWlT-PV"
      },
      "outputs": [],
      "source": [
        "def discriminator_loss(disc_real_output, disc_generated_output):\n",
        "  real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_real_output), \n",
        "                                              logits = disc_real_output)\n",
        "  generated_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.zeros_like(disc_generated_output), \n",
        "                                                   logits = disc_generated_output)\n",
        "\n",
        "  total_disc_loss = real_loss + generated_loss\n",
        "\n",
        "  return total_disc_loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "90BIcCKcDMxz"
      },
      "outputs": [],
      "source": [
        "def generator_loss(disc_generated_output, gen_output, target):\n",
        "  gan_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = tf.ones_like(disc_generated_output),\n",
        "                                             logits = disc_generated_output) \n",
        "  # mean absolute error\n",
        "  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))\n",
        "\n",
        "  total_gen_loss = gan_loss + (LAMBDA * l1_loss)\n",
        "\n",
        "  return total_gen_loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "iWCn_PVdEJZ7"
      },
      "outputs": [],
      "source": [
        "generator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)\n",
        "discriminator_optimizer = tf.train.AdamOptimizer(2e-4, beta1=0.5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aKUZnDiqQrAh"
      },
      "source": [
        "## Checkpoints (Object-based saving)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "WJnftd5sQsv6"
      },
      "outputs": [],
      "source": [
        "checkpoint_dir = './training_checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
        "checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,\n",
        "                                 discriminator_optimizer=discriminator_optimizer,\n",
        "                                 generator=generator,\n",
        "                                 discriminator=discriminator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Rw1fkAczTQYh"
      },
      "source": [
        "## Training\n",
        "\n",
        "* We start by iterating over the dataset\n",
        "* The generator gets the input image and we get a generated output.\n",
        "* The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.\n",
        "* Next, we calculate the generator and the discriminator loss.\n",
        "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables(inputs) and apply those to the optimizer.\n",
        "\n",
        "## Generate Images\n",
        "\n",
        "* After training, its time to generate some images!\n",
        "* We pass images from the test dataset to the generator.\n",
        "* The generator will then translate the input image into the output we expect.\n",
        "* Last step is to plot the predictions and **voila!**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "NS2GWywBbAWo"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 200"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "RmdVsmvhPxyy"
      },
      "outputs": [],
      "source": [
        "def generate_images(model, test_input, tar):\n",
        "  # the training=True is intentional here since\n",
        "  # we want the batch statistics while running the model\n",
        "  # on the test dataset. If we use training=False, we will get \n",
        "  # the accumulated statistics learned from the training dataset\n",
        "  # (which we don't want)\n",
        "  prediction = model(test_input, training=True)\n",
        "  plt.figure(figsize=(15,15))\n",
        "\n",
        "  display_list = [test_input[0], tar[0], prediction[0]]\n",
        "  title = ['Input Image', 'Ground Truth', 'Predicted Image']\n",
        "\n",
        "  for i in range(3):\n",
        "    plt.subplot(1, 3, i+1)\n",
        "    plt.title(title[i])\n",
        "    # getting the pixel values between [0, 1] to plot it.\n",
        "    plt.imshow(display_list[i] * 0.5 + 0.5)\n",
        "    plt.axis('off')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "2M7LmLtGEMQJ"
      },
      "outputs": [],
      "source": [
        "def train(dataset, epochs):  \n",
        "  for epoch in range(epochs):\n",
        "    start = time.time()\n",
        "\n",
        "    for input_image, target in dataset:\n",
        "\n",
        "      with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n",
        "        gen_output = generator(input_image, training=True)\n",
        "\n",
        "        disc_real_output = discriminator(input_image, target, training=True)\n",
        "        disc_generated_output = discriminator(input_image, gen_output, training=True)\n",
        "\n",
        "        gen_loss = generator_loss(disc_generated_output, gen_output, target)\n",
        "        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)\n",
        "\n",
        "      generator_gradients = gen_tape.gradient(gen_loss, \n",
        "                                              generator.variables)\n",
        "      discriminator_gradients = disc_tape.gradient(disc_loss, \n",
        "                                                   discriminator.variables)\n",
        "\n",
        "      generator_optimizer.apply_gradients(zip(generator_gradients, \n",
        "                                              generator.variables))\n",
        "      discriminator_optimizer.apply_gradients(zip(discriminator_gradients, \n",
        "                                                  discriminator.variables))\n",
        "\n",
        "    if epoch % 1 == 0:\n",
        "        clear_output(wait=True)\n",
        "        for inp, tar in test_dataset.take(1):\n",
        "          generate_images(generator, inp, tar)\n",
        "          \n",
        "    # saving (checkpoint) the model every 20 epochs\n",
        "    if (epoch + 1) % 20 == 0:\n",
        "      checkpoint.save(file_prefix = checkpoint_prefix)\n",
        "\n",
        "    print ('Time taken for epoch {} is {} sec\\n'.format(epoch + 1,\n",
        "                                                        time.time()-start))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "a1zZmKmvOH85"
      },
      "outputs": [],
      "source": [
        "train(train_dataset, EPOCHS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "kz80bY3aQ1VZ"
      },
      "source": [
        "## Restore the latest checkpoint and test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4t4x69adQ5xb"
      },
      "outputs": [],
      "source": [
        "# restoring the latest checkpoint in checkpoint_dir\n",
        "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "1RGysMU_BZhx"
      },
      "source": [
        "## Testing on the entire test dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "KUgSnmy2nqSP"
      },
      "outputs": [],
      "source": [
        "# Run the trained model on the entire test dataset\n",
        "for inp, tar in test_dataset:\n",
        "  generate_images(generator, inp, tar)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "pix2pix_eager.ipynb",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp",
          "timestamp": 1527173385672
        }
      ],
      "toc_visible": true,
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
